from transformers.models.llama import LlamaModel, LlamaConfig, LlamaTokenizer
import torch

def run_llama():
    llamaconfig = LlamaConfig(vocab_size=32000, hidden_size=4096//2,
                              intermediate_size=11008//2,
                              num_hidden_layers=32//2,
                              num_attention_heads=32//2,
                              max_position_embeddings=2048//2)

    llamamodel = LlamaModel(config=llamaconfig)

    input_ids = torch.randint(
        low=0, high=llamaconfig.vocab_size, size=(4, 30))

    input_ids = torch.randint(
        low=0, high=llamaconfig.vocab_size, size=(4, 30))

    res = llamamodel(input_ids)
    print(res)

if  __name__ == '__main__':
    run_llama()